import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

from collections import OrderedDict
import time
import matplotlib.pyplot as plt
import torchvision

from util import *

from stargan import Generator, Discriminator, load_stargan

import sys
sys.path.append('')
sys.path.append('')
import warnings
from iaf import IterAlignFlow
from ddl.base import (BoundaryWarning, DataConversionWarning)

from autoencoders.ae_model import AE
warnings.simplefilter('ignore', BoundaryWarning) # Ignore boundary warnings from ddl
warnings.simplefilter('ignore', DataConversionWarning) # Ignore data conversion warnings from ddl


class Client(object):
    '''
    A federated client holding a local model, an Adam optimizer and a data loader.

    Args:
        loader: iterable yielding ``(x, y, d)`` batches (unpacked that way in ``train``).
        config: configuration object; ``device``, ``trans`` and ``lr`` are read here.
    '''

    def __init__(self, loader, config):

        self.loader = loader
        self.config = config
        self.device = config.device

        # Iterator over ``loader``; created lazily on the first ``train`` call.
        self.data_iter = None
        # Per-parameter optimizer state saved after every step. Stays ``None``
        # until the first step so ``restore_optim`` is safe to call at any time.
        self.optimizer_state = None

        self._build_model()

    def _build_model(self):
        '''
        Create the local model, its optimizer and the running loss/accuracy meters.
        '''
        if self.config.trans != 'stargan':
            self.model = FedDIRT(self.config).to(self.device)
        else:
            self.model = FedDIRTStarGAN(self.config).to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
        self.lossMeter = AverageMeter()
        self.accMeter = AverageMeter()

    def restore_optim(self):
        '''
        Rebuild the optimizer for the current ``self.model`` and reload saved state.

        Needed after ``self.model`` has been replaced by a new object: the old
        optimizer still references the previous parameters. The 'state' entry
        saved by the last ``train`` step (if any) is loaded into the new optimizer.
        '''
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
        if self.optimizer_state is not None:
            state = self.optimizer.state_dict()
            state['state'] = self.optimizer_state
            self.optimizer.load_state_dict(state)

    def train(self, fig=False):
        '''
        Update for one mini-batch

        Args:
            fig: forwarded unchanged to the model's forward call.
        '''
        if self.data_iter is None:
            self.data_iter = iter(self.loader)

        try:
            x, y, d = next(self.data_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the data.
            self.data_iter = iter(self.loader)
            x, y, d = next(self.data_iter)

        x, y, d = x.to(self.device), y.to(self.device), d.to(self.device)
        self.optimizer.zero_grad()
        loss, acc = self.model(x, y, d, fig)
        loss.backward()
        self.optimizer.step()
        # Keep the per-parameter state so it survives an optimizer rebuild.
        self.optimizer_state = self.optimizer.state_dict()['state']

        self.lossMeter.update(loss.item(), len(x))
        self.accMeter.update(acc, len(x))


class Central(object):
    '''
    Central server: owns the global model and one ``Client`` per training domain.

    Each training iteration every client takes one mini-batch step; every
    ``sync_step`` iterations client weights are averaged into the global model
    and broadcast back; every ``eval_step`` iterations the global model is
    evaluated on ``test_loader`` and the running tracker is saved to disk.

    Args:
        loader_dict: mapping from domain to that domain's training loader.
        test_loader: iterable yielding ``(x, y)`` batches.
        config: configuration object; reads ``list_train_domains``, ``device``,
            ``iters``, ``sync_step``, ``eval_step``, ``trans`` and ``note``.
    '''

    def __init__(self, loader_dict, test_loader, config):

        self.config = config
        self.list_train_domains = config.list_train_domains
        self.device = config.device

        self.num_iters = config.iters
        print('total iters: ',self.num_iters)

        self.sync_step = config.sync_step
        self.eval_step = config.eval_step

        self._build_model()
        self._init_client(loader_dict)
        self.test_loader = test_loader

        # Running counter that ``train`` increases at every evaluation step.
        self.nparams = 0
        self.model_size = check_nparams(self.model)

    def _build_model(self):
        '''
        Create the global model selected by ``config.trans``.
        '''
        if self.config.trans != 'stargan':
            self.model = FedDIRT(self.config).to(self.device)
        else:
            print(self.config.trans)
            self.model = FedDIRTStarGAN(self.config).to(self.device)

    def _init_client(self, loader_dict):
        '''
        Create one ``Client`` per training domain, keyed by domain.
        '''
        self.clients_dict = {
            domain: Client(loader_dict[domain], self.config)
            for domain in self.list_train_domains
        }

    def _aggregate(self, coeffs=None):
        '''
        Load the weighted sum of all client state dicts into the global model.

        Args:
            coeffs: one weight per domain, in ``list_train_domains`` order.
                Uniform ``1/len(list_train_domains)`` when omitted or empty.
        '''
        n_clients = len(self.list_train_domains)
        if not coeffs:
            coeffs = [1 / n_clients for _ in range(n_clients)]

        # The key set is fixed by the global model; fetch it once.
        keys = list(self.model.state_dict().keys())

        averaged_weights = OrderedDict()
        for i, domain in enumerate(self.list_train_domains):
            local_weight = self.clients_dict[domain].model.state_dict()
            for key in keys:
                if i == 0:
                    averaged_weights[key] = coeffs[i] * local_weight[key]
                else:
                    averaged_weights[key] += coeffs[i] * local_weight[key]
        self.model.load_state_dict(averaged_weights)

    def _transmit(self):
        '''
        Give every client a deep copy of the global model, in training mode.
        '''
        for domain in self.list_train_domains:
            client = self.clients_dict[domain]
            client.model = copy.deepcopy(self.model).to(client.device)
            client.model.train()
            # The copied model has new parameter objects; rebuild the optimizer.
            client.restore_optim()

    def eval(self):
        '''
        Evaluate the global model on the test loader.

        Returns:
            ``(lossMeter, accMeter)``: ``AverageMeter`` objects updated with the
            per-batch loss and accuracy, weighted by batch size.
        '''
        self.model.eval()
        lossMeter = AverageMeter()
        accMeter = AverageMeter()
        # No gradients are needed for evaluation; avoid building autograd graphs.
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(self.test_loader):
                # To device
                x, y = x.to(self.device), y.to(self.device)

                loss, acc = self.model(x, y)

                lossMeter.update(loss.item(), len(x))
                accMeter.update(acc, len(x))

        return lossMeter, accMeter

    def _agg_train_loss(self):
        '''
        Average the clients' running train loss/accuracy and reset their meters.

        Returns:
            ``(loss, acc)``: means over all training domains of each client's
            ``lossMeter.value()`` / ``accMeter.value()``.
        '''
        loss = 0
        acc = 0
        for domain in self.list_train_domains:
            client = self.clients_dict[domain]
            loss += client.lossMeter.value()
            acc += client.accMeter.value()
            client.lossMeter = AverageMeter()
            client.accMeter = AverageMeter()
        loss = loss/len(self.list_train_domains)
        acc = acc/len(self.list_train_domains)
        return loss, acc

    def train(self):
        '''
        Run the full federated training loop.

        Returns:
            dict with keys ``'loss'``, ``'acc'`` (the meters returned by ``eval``
            at each evaluation step) and ``'np'`` (the value of ``self.nparams``
            at each evaluation step).
        '''
        start_iters = 0

        loss_tracker, acc_tracker, np_tracker = [], [], []
        # The lists are appended to in place, so this dict always reflects them.
        tracker = {'loss': loss_tracker, 'acc': acc_tracker, 'np': np_tracker}

        # Start training.
        print('Start training...')
        for i in range(start_iters, self.num_iters):

            # 1. Train local clients for one mini-batch each.
            #    ``fig`` is switched on every 10th iteration.
            fig = (i + 1) % 10 == 0
            for domain in self.list_train_domains:
                self.clients_dict[domain].train(fig=fig)

            # 2. Synchronize central model and local clients.
            if (i + 1) % self.sync_step == 0:
                self._aggregate()
                self._transmit()

            # 3. Evaluate the central model and checkpoint the tracker.
            if (i + 1) % self.eval_step == 0:
                # Called for its side effect of resetting the client meters.
                self._agg_train_loss()
                loss, acc = self.eval()
                print(f'after {i + 1} iters, test loss: {loss}, test acc: {acc}')
                # NOTE(review): presumably a communication-cost estimate — confirm.
                self.nparams += self.eval_step * 2 * self.model_size

                loss_tracker.append(loss)
                acc_tracker.append(acc)
                np_tracker.append(self.nparams)

                save_name = self.config.trans.replace('/', '-') + self.config.note
                torch.save(tracker, f'./saved/{save_name}.pt')

        return tracker



class FedDIRT(nn.Module):
    '''
    Classifier trained with a feature-consistency regularizer.

    The network is two conv blocks (conv, instance norm, ReLU, max-pool), a
    1024 -> 64 linear layer (``fc11``) and a 64 -> 10 classifier (``cls``).
    In training mode each input is additionally translated to another domain
    by a pretrained translator (``self.trans``), and the mean squared
    difference between the ``fc11`` features of original and translated inputs
    is added to the cross-entropy loss, weighted by ``config.reg``.

    Args:
        config: configuration object; reads ``n_domains``, ``use_shared``,
            ``tn``, ``extra``, ``model_dir``, ``device`` and ``reg`` (plus
            whatever ``wrap_enc`` / ``wrap_dec`` read when ``tn == 'indaeinb'``).
    '''

    # Number of domains a random translation target is drawn from.
    # NOTE(review): hard-coded; ``config.n_domains`` is stored but not used for
    # sampling — confirm the two agree before replacing this constant.
    _NUM_TARGET_DOMAINS = 5

    def __init__(self,config):
        super(FedDIRT,self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        self.fc11 = nn.Sequential(nn.Linear(1024, 64))

        torch.nn.init.xavier_uniform_(self.encoder[0].weight)
        torch.nn.init.xavier_uniform_(self.encoder[4].weight)
        torch.nn.init.xavier_uniform_(self.fc11[0].weight)
        self.fc11[0].bias.data.zero_()

        self.cls = nn.Linear(64, 10)
        torch.nn.init.xavier_uniform_(self.cls.weight)
        self.cls.bias.data.zero_()

        self.n_domains = config.n_domains
        self.use_shared = config.use_shared
        self.trans_name = config.tn

        self.extra = config.extra

        # (1) Whether to translate into the shared space instead of a target domain.
        if self.use_shared:
            print('Use shared space !!!')

        # (2) Whether the translator works on autoencoder latents.
        if config.tn == 'indaeinb':
            self._load_indae(config)

        # (3) Load the pretrained translator; ``model_dir`` is passed directly
        #     to ``torch.load`` as the file to read.
        model_dir = config.model_dir
        print(model_dir)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # translator files from a trusted source.
        self.trans = torch.load(model_dir,map_location=config.device)

        self.reg = config.reg
        self.device = config.device

    def forward(self,x,y,d=None, fig=False):
        '''
        Compute the loss and accuracy for one batch.

        Args:
            x: input batch; must be viewable as ``(-1, 1, 28, 28)`` for translation.
            y: class labels, used as ``cross_entropy`` targets.
            d: domain labels of ``x``; required in training mode.
            fig: accepted for interface compatibility; currently has no effect.

        Returns:
            ``(loss, acc)``: ``loss`` is a scalar tensor (cross-entropy, plus
            ``reg`` times the feature MSE in training mode); ``acc`` is a float.

        Raises:
            ValueError: in training mode, if ``d`` is missing or ``config.tn``
                is not ``'inb'`` or ``'indaeinb'``.
        '''
        if self.training:
            if d is None:
                raise ValueError('FedDIRT.forward requires domain labels `d` in training mode')
            if self.extra:
                # NOTE(review): this translation runs with autograd enabled,
                # unlike the regularizer translation below — confirm intended.
                x,y,d = self._extra_data(x,y,d)

        h = self.encoder(x)
        h = h.view(-1, 1024)
        z = self.fc11(h)
        logits = self.cls(F.relu(z))
        loss = F.cross_entropy(logits, y)
        acc = ((logits.argmax(1)==y).sum().float()/len(y)).item()

        if not self.training:
            return loss, acc

        with torch.no_grad():
            n_targets = self._NUM_TARGET_DOMAINS
            if self.extra:
                # One target domain shared by the whole (augmented) batch.
                target_d = (torch.ones(x.shape[0]) * np.random.choice(n_targets)).to(torch.int64).to(self.device)
            else:
                # An independent target domain per sample.
                target_d = torch.Tensor([np.random.choice(n_targets) for _ in range(x.shape[0])]).to(torch.int64).to(self.device)
            x_ = self._translate(x, y, d, target_d)

        h_ = self.encoder(x_)
        h_ = h_.view(-1, 1024)
        z_ = self.fc11(h_)
        reg = F.mse_loss(z_,z,reduction='mean')
        loss = loss + self.reg * reg
        return loss, acc

    def _translate(self, x, y, d, target_d):
        '''
        Dispatch to the translation routine selected by ``trans_name``.
        '''
        if self.trans_name == 'inb':
            if self.use_shared:
                return self._inb_findshare(x, y, d)
            return self._inb(x, y, d, target_d)
        if self.trans_name == 'indaeinb':
            return self._indaeinb(x, y, d, target_d)
        raise ValueError(
            "Unsupported translator {!r}; expected 'inb' or 'indaeinb'".format(self.trans_name)
        )

    def _extra_data(self,x,y,d):
        '''
        Double the batch by appending a translated copy of every sample.

        Each sample is translated to a randomly drawn target domain; labels are
        duplicated and the drawn domains are appended to ``d``.
        '''
        de = torch.Tensor([np.random.choice(self._NUM_TARGET_DOMAINS) for _ in range(x.shape[0])]).to(torch.int64).to(self.device)
        if self.trans_name == 'inb':
            xe = self._inb(x, y, d, de)
        else:
            xe = self._indaeinb(x, y, d, de)

        x = torch.cat((x,xe),dim=0)
        d = torch.cat((d,de))
        y = torch.cat((y,y))
        return x,y,d

    def _inb(self,x,y,d,target_d):
        '''
        Translate flattened inputs from domains ``d`` to ``target_d``.
        '''
        x_ = batch_inb_translate(self.trans, x.view(-1, 784), y, d, target_d).view(-1, 1, 28, 28)
        return x_

    def _inb_findshare(self,x,y,d):
        '''
        Map flattened inputs through the translator's forward pass only.
        '''
        x_ = batch_inb_findshare(self.trans, x.view(-1, 784), y, d).view(-1, 1, 28, 28)
        return x_

    def _load_indae(self, args):
        '''
        Load the per-domain autoencoder encoders and decoders; returns ``self``.
        '''
        self.enc = wrap_enc(args)
        self.dec = wrap_dec(args)
        return self

    def _indaeinb(self,x,y,d,target_d):
        '''
        Translate via the autoencoder latents: encode, translate, decode.
        '''
        x_ = batch_indaeinb_translate(self.trans, x.view(-1, 784), y, d,
                                 target_d,
                                 self.enc,self.dec,
                                 ).view(-1, 1, 28, 28)
        return x_


class FedDIRTStarGAN(nn.Module):
    '''
    Classifier with a feature-consistency regularizer driven by a StarGAN generator.

    Same classifier layout as ``FedDIRT`` (two conv blocks, ``fc11`` 1024 -> 64,
    ``cls`` 64 -> 10). In training mode each input is translated to a randomly
    drawn domain by the pretrained generator ``self.trans``, and the MSE between
    ``fc11`` features of original and translated inputs is added to the loss.

    Args:
        config: configuration object; reads ``dataset``, ``model_dir``,
            ``device`` and ``reg``.
    '''

    # Number of domains used for one-hot domain codes and target sampling.
    _NUM_DOMAINS = 5

    def __init__(self,config):
        super(FedDIRTStarGAN,self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        self.fc11 = nn.Sequential(nn.Linear(1024, 64))

        torch.nn.init.xavier_uniform_(self.encoder[0].weight)
        torch.nn.init.xavier_uniform_(self.encoder[4].weight)
        torch.nn.init.xavier_uniform_(self.fc11[0].weight)
        self.fc11[0].bias.data.zero_()

        self.cls = nn.Linear(64, 10)
        torch.nn.init.xavier_uniform_(self.cls.weight)
        self.cls.bias.data.zero_()

        # NOTE(review): the domain in the checkpoint name is hard-coded to 75
        # rather than taken from config.target_domain — confirm this is intended.
        ckpt_path = config.model_dir + 'stargan_model/{}_domain{}_last-G.ckpt'.format(config.dataset,75)
        # Report the path that is actually loaded.
        print(ckpt_path)
        self.trans = load_stargan(ckpt=ckpt_path)
        self.trans.eval()

        self.device = config.device
        self.reg = config.reg

    def train(self, mode=True):
        '''
        Set the training mode while keeping the pretrained generator in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo the
        ``self.trans.eval()`` done at construction whenever ``model.train()`` is
        called on this module.
        '''
        super(FedDIRTStarGAN, self).train(mode)
        self.trans.eval()
        return self

    def forward(self,x,y,d=None,fig=False):
        '''
        Compute the loss and accuracy for one batch.

        Args:
            x: input batch fed to the encoder and, in training, to the generator.
            y: class labels, used as ``cross_entropy`` targets.
            d: integer domain labels of ``x``; required in training mode.
            fig: accepted for interface compatibility; has no effect.

        Returns:
            ``(loss, acc)``: ``loss`` is a scalar tensor (cross-entropy, plus
            ``reg`` times the feature MSE in training mode); ``acc`` is a float.

        Raises:
            ValueError: if ``d`` is missing in training mode.
        '''
        h = self.encoder(x)
        h = h.view(-1, 1024)
        z = self.fc11(h)

        logits = self.cls(F.relu(z))
        loss = F.cross_entropy(logits, y)
        acc = ((logits.argmax(1)==y).sum().float()/len(y)).item()

        if self.training:
            if d is None:
                raise ValueError('FedDIRTStarGAN.forward requires domain labels `d` in training mode')
            n_domains = self._NUM_DOMAINS
            with torch.no_grad():
                # One-hot code of the source domain of every sample.
                one_hot_d = x.new_zeros([x.shape[0],n_domains])
                one_hot_d.scatter_(1, d[:,None], 1)
                # Independent random target domain per sample.
                d_ = torch.Tensor([np.random.choice(n_domains) for _ in range(x.shape[0])]).to(torch.int64).to(self.device)
                one_hot_d_ = x.new_zeros([x.shape[0],n_domains])
                one_hot_d_.scatter_(1, d_[:,None], 1)
                x_ = self.trans(x,one_hot_d,one_hot_d_)

            h_ = self.encoder(x_)
            h_ = h_.view(-1, 1024)
            z_ = self.fc11(h_)
            reg = F.mse_loss(z_,z,reduction='mean')
            loss = loss + reg * self.reg

        return loss, acc


def inb_translate(cd, x, d, target_d):
    '''
    Translate ``x`` from domain(s) ``d`` to ``target_d`` through ``cd``.

    ``cd(x, d)`` is applied first, then ``cd.inverse`` with ``target_d`` is
    applied to its result.
    '''
    latent = cd(x, d)
    return cd.inverse(latent, target_d)

def batch_inb_translate(cd_dict,x,y,d, target_d):
    '''
    Translate a batch class by class.

    For every distinct label in ``y`` the matching rows of ``x`` are translated
    with ``cd_dict[int(label)]`` and written back to the same row positions of
    the output, which has the shape and dtype of ``x``.
    '''
    translated = torch.zeros_like(x).to(x.device)
    for label in torch.unique(y):
        rows = y == label
        x_rows, d_rows, target_rows = x[rows], d[rows], target_d[rows]
        translated[rows] = inb_translate(cd_dict[int(label)], x_rows, d_rows, target_rows)
    return translated




def indaeinb_translate(cd, x, d, target_d,enc,dec):
    '''
    Translate ``x`` between domains in autoencoder latent space.

    Rows are encoded with ``enc`` grouped by source domain (392 values per row),
    passed through ``cd`` and ``cd.inverse`` towards ``target_d``, then decoded
    with ``dec`` grouped by target domain (784 values per row).
    '''
    n_rows = x.shape[0]

    # Encode each source-domain group with that domain's encoder.
    latent = torch.zeros(n_rows, 392).to(x.device)
    for src in torch.unique(d):
        src_rows = d == src
        latent[src_rows] = enc(x[src_rows], src)

    # Move to the target domain in latent space.
    latent_trans = cd.inverse(cd(latent, d), target_d)

    # Decode each target-domain group with that domain's decoder.
    decoded = torch.zeros(n_rows, 784).to(x.device)
    for tgt in torch.unique(target_d):
        tgt_rows = target_d == tgt
        decoded[tgt_rows] = dec(latent_trans[tgt_rows], tgt)
    return decoded

def batch_indaeinb_translate(cd_dict,x,y,d, target_d, enc,dec):
    '''
    Translate a batch class by class through the autoencoder latent space.

    For every distinct label in ``y`` the matching rows are translated with
    ``cd_dict[int(label)]`` (using ``enc`` / ``dec``) and written back to the
    same row positions of the output, which has the shape and dtype of ``x``.
    '''
    translated = torch.zeros_like(x).to(x.device)
    for label in torch.unique(y):
        rows = y == label
        x_rows, d_rows, target_rows = x[rows], d[rows], target_d[rows]
        translated[rows] = indaeinb_translate(
            cd_dict[int(label)], x_rows, d_rows, target_rows, enc, dec
        )
    return translated



def batch_inb_findshare(cd_dict,x,y,d):
    '''
    Apply each class's translator forward pass to that class's rows.

    For every distinct label in ``y`` the matching rows of ``x`` (with their
    domains from ``d``) are passed to ``cd_dict[int(label)]`` and the result is
    written to the same row positions of the output, which has the shape and
    dtype of ``x``.
    '''
    shared = torch.zeros_like(x).to(x.device)
    for label in torch.unique(y):
        rows = y == label
        x_rows, d_rows = x[rows], d[rows]
        shared[rows] = cd_dict[int(label)](x_rows, d_rows)
    return shared

class wrap_enc(nn.Module):
    '''
    Holds one pretrained autoencoder encoder per training domain.

    Args:
        args: configuration object; reads ``list_train_domains``, ``ae_dir``
            and ``device`` (and is passed to ``AE``).
    '''

    def __init__(self, args):
        super().__init__()
        # Plain Python list: the encoders are not registered as submodules, so
        # they do not appear in this module's parameters() or state_dict().
        self.ae_list = []
        for dd in args.list_train_domains:
            ae = AE(args)
            ae_path = args.ae_dir + '/ae' + '-' + str(dd) + '.pt'
            # map_location lets checkpoints saved on another device (e.g. GPU)
            # load on the configured one.
            ae.load_state_dict(torch.load(ae_path, map_location=args.device))
            ae = ae.to(args.device)
            self.ae_list.append(ae.encoder)
            print(f'Finish loading encoder from {ae_path}')

    def forward(self, X, y):
        '''
        Encode ``X`` with the encoder at position ``int(y)``.

        Args:
            X: batch viewable as ``(-1, 1, 28, 28)``.
            y: index into the encoder list, i.e. a position in
                ``args.list_train_domains`` order.

        Returns:
            The encoder output flattened to ``(batch, -1)``.
        '''
        X = X.view(-1, 1, 28, 28)
        return self.ae_list[int(y)](X).view(X.shape[0], -1)


class wrap_dec(nn.Module):
    '''
    Holds one pretrained autoencoder decoder per training domain.

    Args:
        args: configuration object; reads ``list_train_domains``, ``ae_dir``
            and ``device`` (and is passed to ``AE``).
    '''

    def __init__(self, args):
        super().__init__()
        # Plain Python list: the decoders are not registered as submodules, so
        # they do not appear in this module's parameters() or state_dict().
        self.ae_list = []
        for dd in args.list_train_domains:
            ae = AE(args)
            ae_path = args.ae_dir + '/ae' + '-' + str(dd) + '.pt'
            # map_location lets checkpoints saved on another device (e.g. GPU)
            # load on the configured one.
            ae.load_state_dict(torch.load(ae_path, map_location=args.device))
            ae = ae.to(args.device)
            self.ae_list.append(ae.decoder)
            print(f'Finish loading decoder from {ae_path}')

    def forward(self, X, y):
        '''
        Decode ``X`` with the decoder at position ``int(y)``.

        Args:
            X: batch of latents viewable as ``(-1, 8, 7, 7)``.
            y: index into the decoder list, i.e. a position in
                ``args.list_train_domains`` order.

        Returns:
            The decoder output flattened to ``(batch, -1)``.
        '''
        X = X.view(-1, 8, 7, 7)
        return self.ae_list[int(y)](X).view(X.shape[0], -1)